import random
import numpy as np
import os
import logging
import sys
import torch
from tqdm import tqdm
from nats_bench import create
import time

# Public API: the only name exported by a star-import of this module.
__all__ = ["RandomFinderNASBench201"]


class ArchManager:
    """Samples NAS-Bench-201 cells and renders them as architecture strings."""

    def __init__(self):
        # Candidate operations; an edge is encoded as an index into this list.
        self.operations = ['none', 'skip_connect', 'nor_conv_1x1', 'nor_conv_3x3', 'avg_pool_3x3']
        self.num_ops = len(self.operations)
        # A NAS-Bench-201 cell has six edges.
        self.num_edges = 6

    def random_sample(self):
        """Return the architecture string of a uniformly sampled cell."""
        highest_index = self.num_ops - 1
        picks = []
        for _ in range(self.num_edges):
            picks.append(random.randint(0, highest_index))
        return self._ops_to_arch_str(picks)

    def _ops_to_arch_str(self, ops):
        """Render six operation indices as a NAS-Bench-201 architecture string.

        Raises:
            ValueError: if ``ops`` does not hold exactly ``num_edges`` entries.
        """
        count = len(ops)
        if count != self.num_edges:
            raise ValueError(f"Expected {self.num_edges} operations, got {count}")

        names = [self.operations[ops[position]] for position in range(6)]
        # Node k receives k incoming edges, one from each earlier node, so the
        # six names split into groups of 1, 2 and 3.
        per_node = (names[0:1], names[1:3], names[3:6])
        rendered = []
        for incoming in per_node:
            edges = "|".join(f"{op_name}~{source}" for source, op_name in enumerate(incoming))
            rendered.append(f"|{edges}|")
        return "+".join(rendered)


class AccuracyPredictor:
    """Looks up architecture accuracy through the NATS-Bench topology ('tss') API."""

    def __init__(self, dataset='cifar10', metric='test'):
        """Create the benchmark API handle and store the query settings.

        If anything goes wrong while setting up, two diagnostic lines are
        printed and the process exits with status 1.
        """
        try:
            benchmark = create(None, 'tss', fast_mode=True, verbose=False)
            self.api = benchmark
            self.dataset, self.metric = dataset, metric

            print(f"Successfully loaded NATS-Bench TSS API for {dataset} with {metric} metric")
        except Exception as err:
            for line in (
                f"Unable to load NATS-Bench TSS API: {err}",
                "Please ensure you have downloaded the NATS-Bench TSS data file and set the correct path",
            ):
                print(line)
            sys.exit(1)

    def predict_accuracy(self, arch):
        """Return the benchmark accuracy recorded for ``arch``.

        The validation accuracy is returned when the metric is ``'valid'``,
        the test accuracy otherwise. Any failure during the lookup is printed
        and reported as an accuracy of ``0.0``.
        """
        try:
            index = self.api.query_index_by_arch(arch)
            wants_valid = self.metric == 'valid'
            # cifar10 with the validation metric is queried under the
            # 'cifar10-valid' dataset name; everything else uses the name as is.
            if wants_valid and self.dataset == 'cifar10':
                dataset_name = 'cifar10-valid'
            else:
                dataset_name = self.dataset
            info = self.api.get_more_info(index, dataset_name, hp='200', is_random=False)

            accuracy_key = 'valid-accuracy' if wants_valid else 'test-accuracy'
            return info[accuracy_key]
        except Exception as err:
            print(f"Error querying architecture performance: {err}")
            return 0.0  # failed queries count as the worst possible score


class RandomFinderNASBench201:
    """Simple random search algorithm for NAS-Bench-201 search space.

    Architectures are drawn from ``ArchManager`` without repetition, scored
    through ``AccuracyPredictor``, and the best one seen so far is tracked.

    Args:
        dataset: Dataset name forwarded to the accuracy predictor.
        logger: Logger for progress messages. When ``None``, a module-level
            logger is used so that logging calls never hit ``None``.
        **kwargs: Optional settings:

            * ``metric`` (default ``"test"``): metric handed to the predictor.
            * ``seed`` (default ``0``): seed for ``random``, NumPy and torch;
              ``None`` skips seeding.
            * ``max_explored`` (default ``5000``): number of architectures to
              evaluate before the search stops.
    """

    def __init__(
        self,
        dataset='cifar10',
        logger=None,
        **kwargs
    ):
        self.dataset = dataset
        self.arch_manager = ArchManager()
        self.metric = kwargs.get("metric", "test")
        self.accuracy_predictor = AccuracyPredictor(dataset, self.metric)
        # The methods below log unconditionally, so never leave this as None.
        self.logger = logger if logger is not None else logging.getLogger(__name__)

        # Random search parameters
        self.seed = kwargs.get("seed", 0)
        self.max_explored = kwargs.get("max_explored", 5000)

        # Track visited architectures and cache results
        self.visited = set()  # Architecture strings already handed out
        self.arch_performances = {}  # Architecture string -> queried accuracy

        # Track search progress
        self.total_explored = 0  # Total architectures explored
        self.best_arch = None  # Best architecture found
        self.best_acc = 0.0  # Best accuracy
        self.best_found_at = 0  # Sample number when best architecture was found

        # Set random seed
        if self.seed is not None:
            random.seed(self.seed)
            np.random.seed(self.seed)
            torch.manual_seed(self.seed)
            torch.cuda.manual_seed_all(self.seed)
            print(f"Random seed set to: {self.seed}")

    def _search_space_size(self):
        """Return the number of distinct architectures the manager can emit."""
        return self.arch_manager.num_ops ** self.arch_manager.num_edges

    def _sample_unvisited(self):
        """Return a not-yet-visited architecture, or ``None`` if none is left.

        The returned architecture is recorded in ``self.visited``. Checking
        for exhaustion first keeps the rejection loop from spinning forever
        once every architecture has been handed out.
        """
        if len(self.visited) >= self._search_space_size():
            return None
        while True:
            arch = self.arch_manager.random_sample()
            if arch not in self.visited:
                self.visited.add(arch)
                return arch

    def evaluate_arch(self, arch_str):
        """Evaluate a single architecture and update best architecture if needed.

        Accuracies are cached per architecture string, so the predictor is
        queried at most once per architecture. Every call counts as one
        explored sample, whether or not the cache was hit.

        Args:
            arch_str: Architecture string to score.

        Returns:
            The accuracy reported by the predictor (or the cached value).
        """
        if arch_str in self.arch_performances:
            acc = self.arch_performances[arch_str]
        else:
            acc = self.accuracy_predictor.predict_accuracy(arch_str)
            self.arch_performances[arch_str] = acc

        self.total_explored += 1

        # Update best architecture
        if acc > self.best_acc:
            self.best_acc = acc
            self.best_arch = arch_str
            self.best_found_at = self.total_explored
            self.logger.info(f"New best architecture found! Acc: {acc:.4f}, Sample: {self.total_explored}")
            self.logger.info(f"Architecture: {arch_str}")

        return acc

    def run_random_search(self):
        """Run random search to find optimal architecture.

        The search stops after ``max_explored`` evaluations, or earlier if the
        whole search space has been visited.

        Returns:
            A tuple ``(best_valids, [best_acc, best_arch], best_found_at)``.
            ``best_valids`` starts with ``0.0`` and receives the running best
            accuracy after every 50th explored sample.
        """
        self.logger.info(f"Starting random search for NAS-Bench-201 on {self.dataset} with {self.metric} metric...")
        self.logger.info(f"Search will stop after exploring {self.max_explored} architectures")

        # Track best accuracy over time
        best_valids = [0.0]

        with tqdm(total=self.max_explored) as pbar:
            while self.total_explored < self.max_explored:
                arch = self._sample_unvisited()
                if arch is None:
                    # Nothing new left to sample; continuing would loop forever.
                    self.logger.warning(
                        f"Search space exhausted after {len(self.visited)} unique architectures; stopping early"
                    )
                    break

                self.evaluate_arch(arch)
                pbar.update(1)

                # Log progress every 100 samples
                if self.total_explored % 100 == 0:
                    self.logger.info(f"Architectures explored: {self.total_explored}/{self.max_explored}")
                    self.logger.info(f"Current best accuracy: {self.best_acc:.4f}")
                    pbar.set_description(f"Best: {self.best_acc:.4f}")

                # Record best accuracy every 50 samples
                if self.total_explored % 50 == 0:
                    best_valids.append(self.best_acc)

        # Guard the ratio: nothing is explored when max_explored <= 0.
        if self.total_explored:
            found_fraction = self.best_found_at / self.total_explored
        else:
            found_fraction = 0.0

        # Output final results
        self.logger.info("\nRandom search completed!")
        self.logger.info(f"Total explored architectures: {self.total_explored}")
        self.logger.info(f"Best architecture found at sample #{self.best_found_at} (after {found_fraction:.2%} of total):")
        self.logger.info(f"Accuracy: {self.best_acc:.4f}")
        self.logger.info(f"Architecture: {self.best_arch}")

        return best_valids, [self.best_acc, self.best_arch], self.best_found_at


def main():
    """Command-line entry point: read options, set up logging, run the search."""
    import argparse

    cli = argparse.ArgumentParser(
        description='Search for high-performance architectures in NAS-Bench-201 using random search'
    )
    cli.add_argument(
        '--dataset',
        type=str,
        default='cifar10',
        choices=['cifar10', 'cifar100', 'ImageNet16-120'],
        help='Dataset to optimize architecture for',
    )
    cli.add_argument(
        '--metric',
        type=str,
        default='test',
        choices=['test', 'valid'],
        help='Which metric to use for evaluation (test or validation accuracy)',
    )
    cli.add_argument('--seed', type=int, default=None, help='Random seed for reproducibility')
    cli.add_argument('--max_explored', type=int, default=100, help='Maximum number of architectures to explore')
    opts = cli.parse_args()

    # Logs are grouped by script name, then by dataset/metric pair.
    script_name = os.path.basename(__file__).partition('.')[0]
    log_dir = f"search_logs/{script_name}/{opts.dataset}-{opts.metric}"
    os.makedirs(log_dir, exist_ok=True)

    # The timestamp keeps repeated runs with the same settings in separate files.
    stamp = time.strftime("%Y%m%d_%H%M%S")
    log_name = f"seed{opts.seed}_{opts.dataset}_{opts.metric}_max{opts.max_explored}_{stamp}.log"
    to_file = logging.FileHandler(os.path.join(log_dir, log_name))
    to_console = logging.StreamHandler()
    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
        handlers=[to_file, to_console],
    )
    logger = logging.getLogger("NAS-Random")

    search = RandomFinderNASBench201(
        dataset=opts.dataset,
        logger=logger,
        seed=opts.seed,
        metric=opts.metric,
        max_explored=opts.max_explored,
    )

    logger.info(f"Starting random search on {opts.dataset} with {opts.metric} metric")
    logger.info(f"Max architectures to explore: {opts.max_explored}")
    _, best_info, best_found_at = search.run_random_search()
    best_acc, best_arch = best_info

    # Final summary
    logger.info("\nBest architecture:")
    logger.info(f"Best architecture found at sample: #{best_found_at}")
    logger.info(f"Architecture string: {best_arch}")
    logger.info(f"Accuracy on {opts.dataset} ({opts.metric}): {best_acc:.4f}")


# Run the command-line entry point only when this file is executed as a script.
if __name__ == "__main__":
    main()
